# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loss functions to be used by LayerCollection."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import normal


@six.add_metaclass(abc.ABCMeta)
class LossFunction(object):
  """Abstract base class for loss functions.

  Note that unlike typical loss functions used in neural networks these are
  summed and not averaged across cases in the batch, since this is what the
  users of this class (FisherEstimator and MatrixVectorProductComputer) will
  be expecting. The implication of this is that you may want to normalize
  things like Fisher-vector products by the batch size when you use this
  class.  It depends on the use case.
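
  For example (a minimal sketch, where `loss` is an instance of a concrete
  subclass and `batch_size` is assumed to be defined by the caller):

    total_loss = loss.evaluate()          # summed over cases in the batch
    mean_loss = total_loss / batch_size   # normalize if an average is desired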
  """

  @abc.abstractproperty
  def targets(self):
    """The targets being predicted by the model.

    Returns:
      None or Tensor of appropriate shape for calling self._evaluate() on.
    """
    pass

  @abc.abstractproperty
  def inputs(self):
    """The inputs to the loss function (excluding the targets)."""
    pass

  @property
  def input_minibatches(self):
    """A `list` of inputs to the loss function, separated by minibatch.

    Typically there will be one minibatch per tower in a multi-tower setup.
    Returns a list consisting of `self.inputs` by default; `LossFunction`s
    supporting registering multiple minibatches should override this method.

    Returns:
      A `list` of `Tensor`s representing the registered inputs, one entry per
      minibatch.
    """
    return [self.inputs]

  @property
  def num_registered_minibatches(self):
    """Number of minibatches registered for this LossFunction.

    Typically equal to the number of towers in a multi-tower setup.

    Returns:
      An `int` representing the number of registered minibatches.
    """
    return len(self.input_minibatches)

  def evaluate(self):
    """Evaluate the loss function on the targets."""
    if self.targets is not None:
      # We treat the targets as "constant".  It's only the inputs that get
      # "back-propped" through.
      return self._evaluate(array_ops.stop_gradient(self.targets))
    else:
      raise Exception("Cannot evaluate losses with unspecified targets.")

  @abc.abstractmethod
  def _evaluate(self, targets):
    """Evaluates the log probability of the targets.

    Args:
      targets: Tensor that distribution can calculate log_prob() of.

    Returns:
      log probability of each target, summed across all targets.
    """
    pass

  @abc.abstractmethod
  def multiply_hessian(self, vector):
    """Right-multiply a vector by the Hessian.

    Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
    of the loss function with respect to its inputs.

    Args:
      vector: The vector to multiply.  Must be the same shape(s) as the
        'inputs' property.

    Returns:
      The vector right-multiplied by the Hessian.  Will be of the same shape(s)
      as the 'inputs' property.
    """
    pass

  @abc.abstractmethod
  def multiply_hessian_factor(self, vector):
    """Right-multiply a vector by a factor B of the Hessian.

    Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
    of the loss function with respect to its inputs.  Typically this will be
    block-diagonal across different cases in the batch, since the loss function
    is typically summed across cases.

    Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
    but will agree with the one used in the other methods of this class.

    Args:
      vector: The vector to multiply.  Must be of the shape given by the
        'hessian_factor_inner_shape' property.

    Returns:
      The vector right-multiplied by B.  Will be of the same shape(s) as the
      'inputs' property.
    """
    pass

  @abc.abstractmethod
  def multiply_hessian_factor_transpose(self, vector):
    """Right-multiply a vector by the transpose of a factor B of the Hessian.

    Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
    of the loss function with respect to its inputs.  Typically this will be
    block-diagonal across different cases in the batch, since the loss function
    is typically summed across cases.

    Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
    but will agree with the one used in the other methods of this class.

    Args:
      vector: The vector to multiply.  Must be the same shape(s) as the
        'inputs' property.

    Returns:
      The vector right-multiplied by B^T.  Will be of the shape given by the
      'hessian_factor_inner_shape' property.
    """
    pass

  @abc.abstractmethod
  def multiply_hessian_factor_replicated_one_hot(self, index):
    """Right-multiply a replicated-one-hot vector by a factor B of the Hessian.

    Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
    of the loss function with respect to its inputs.  Typically this will be
    block-diagonal across different cases in the batch, since the loss function
    is typically summed across cases.

    A 'replicated-one-hot' vector means a tensor which, for each slice along the
    batch dimension (assumed to be dimension 0), is 1.0 in the entry
    corresponding to the given index and 0 elsewhere.

    Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
    but will agree with the one used in the other methods of this class.

    Args:
      index: A tuple representing the index of the entry in each slice that
        is 1.0. Note that len(index) must be equal to the number of elements
        of the 'hessian_factor_inner_shape' tensor minus one.

    Returns:
      The vector right-multiplied by B. Will be of the same shape(s) as the
      'inputs' property.
    """
    pass

  @abc.abstractproperty
  def hessian_factor_inner_shape(self):
    """The shape of the tensor returned by multiply_hessian_factor."""
    pass

  @abc.abstractproperty
  def hessian_factor_inner_static_shape(self):
    """Static version of hessian_factor_inner_shape."""
    pass


@six.add_metaclass(abc.ABCMeta)
class NegativeLogProbLoss(LossFunction):
  """Abstract base class for loss functions that are negative log probs."""

  def __init__(self, seed=None):
    self._default_seed = seed
    super(NegativeLogProbLoss, self).__init__()

  @property
  def inputs(self):
    return self.params

  @abc.abstractproperty
  def params(self):
    """Parameters to the underlying distribution."""
    pass

  @abc.abstractmethod
  def multiply_fisher(self, vector):
    """Right-multiply a vector by the Fisher.

    Args:
      vector: The vector to multiply.  Must be the same shape(s) as the
        'inputs' property.

    Returns:
      The vector right-multiplied by the Fisher.  Will be of the same shape(s)
      as the 'inputs' property.
    """
    pass

  @abc.abstractmethod
  def multiply_fisher_factor(self, vector):
    """Right-multiply a vector by a factor B of the Fisher.

    Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
    product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). Typically this
    will be block-diagonal across different cases in the batch, since the
    distribution is usually (but not always) conditionally iid across different
    cases.

    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
    but will agree with the one used in the other methods of this class.

    Args:
      vector: The vector to multiply.  Must be of the shape given by the
        'fisher_factor_inner_shape' property.

    Returns:
      The vector right-multiplied by B. Will be of the same shape(s) as the
      'inputs' property.
    """
    pass

  @abc.abstractmethod
  def multiply_fisher_factor_transpose(self, vector):
    """Right-multiply a vector by the transpose of a factor B of the Fisher.

    Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
    product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). Typically this
    will be block-diagonal across different cases in the batch, since the
    distribution is usually (but not always) conditionally iid across different
    cases.

    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
    but will agree with the one used in the other methods of this class.

    Args:
      vector: The vector to multiply.  Must be the same shape(s) as the
        'inputs' property.

    Returns:
      The vector right-multiplied by B^T.  Will be of the shape given by the
      'fisher_factor_inner_shape' property.
    """
    pass

  @abc.abstractmethod
  def multiply_fisher_factor_replicated_one_hot(self, index):
    """Right-multiply a replicated-one-hot vector by a factor B of the Fisher.

    Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
    product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). Typically this
    will be block-diagonal across different cases in the batch, since the
    distribution is usually (but not always) conditionally iid across different
    cases.

    A 'replicated-one-hot' vector means a tensor which, for each slice along the
    batch dimension (assumed to be dimension 0), is 1.0 in the entry
    corresponding to the given index and 0 elsewhere.

    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
    but will agree with the one used in the other methods of this class.

    Args:
      index: A tuple representing the index of the entry in each slice that
        is 1.0. Note that len(index) must be equal to the number of elements
        of the 'fisher_factor_inner_shape' tensor minus one.

    Returns:
      The vector right-multiplied by B. Will be of the same shape(s) as the
      'inputs' property.
    """
    pass

  @abc.abstractproperty
  def fisher_factor_inner_shape(self):
    """The shape of the tensor returned by multiply_fisher_factor."""
    pass

  @abc.abstractproperty
  def fisher_factor_inner_static_shape(self):
    """Static version of fisher_factor_inner_shape."""
    pass

  @abc.abstractmethod
  def sample(self, seed):
    """Sample 'targets' from the underlying distribution."""
    pass

  def evaluate_on_sample(self, seed=None):
    """Evaluates the log probability on a random sample.

    Args:
      seed: int or None. Random seed for this draw from the distribution.

    Returns:
      Negative log probability of the sampled targets, summed across examples.
    """
    if seed is None:
      seed = self._default_seed
    # We treat the targets as "constant".  It's only the inputs that get
    # "back-propped" through.
    return self._evaluate(array_ops.stop_gradient(self.sample(seed)))


# TODO(jamesmartens): should this just inherit from object to avoid "diamond"
# inheritance, or is there a better way?
class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss):
  """Base class for neg log prob losses whose inputs are 'natural' parameters.

  Note that the Hessian and Fisher for natural parameters of exponential-
  family models are the same, hence the purpose of this class.
  See here: https://arxiv.org/abs/1412.1193

  'Natural parameters' are defined for exponential-family models. See for
  example: https://en.wikipedia.org/wiki/Exponential_family
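
  Concretely, for an exponential-family model p(x | theta) = h(x) *
  exp(theta^T t(x) - A(theta)) with natural parameters theta, the Hessian of
  -log p with respect to theta is the Hessian of the log-partition function
  A(theta), which equals Cov[t(x)], and this is exactly the Fisher information.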
  """

  def multiply_hessian(self, vector):
    return self.multiply_fisher(vector)

  def multiply_hessian_factor(self, vector):
    return self.multiply_fisher_factor(vector)

  def multiply_hessian_factor_transpose(self, vector):
    return self.multiply_fisher_factor_transpose(vector)

  def multiply_hessian_factor_replicated_one_hot(self, index):
    return self.multiply_fisher_factor_replicated_one_hot(index)

  @property
  def hessian_factor_inner_shape(self):
    return self.fisher_factor_inner_shape

  @property
  def hessian_factor_inner_static_shape(self):
    return self.fisher_factor_inner_static_shape


class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
  """Base class for neg log prob losses that use the TF Distribution classes."""

  def __init__(self, seed=None):
    super(DistributionNegativeLogProbLoss, self).__init__(seed=seed)

  @abc.abstractproperty
  def dist(self):
    """The underlying tf.distributions.Distribution."""
    pass

  def _evaluate(self, targets):
    return -math_ops.reduce_sum(self.dist.log_prob(targets))

  def sample(self, seed):
    return self.dist.sample(seed=seed)


class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
                                    NaturalParamsNegativeLogProbLoss):
  """Neg log prob loss for a normal distribution parameterized by a mean vector.


  Note that the covariance is treated as a constant 'var' times the identity.
  Also note that the Fisher for such a normal distribution, with respect to the
  mean parameter, is given by:

     F = (1/var) * I

  See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
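
  A minimal usage sketch (`mean_tensor` and `target_tensor` are assumed to be
  float Tensors of shape [batch_size, output_size] defined elsewhere):

    loss = NormalMeanNegativeLogProbLoss(mean_tensor, var=1.0,
                                         targets=target_tensor)
    neg_log_prob = loss.evaluate()  # summed across cases in the batch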
  """

  def __init__(self, mean, var=0.5, targets=None, seed=None):
    self._mean = mean
    self._var = var
    self._targets = targets
    super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed)

  @property
  def targets(self):
    return self._targets

  @property
  def dist(self):
    return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var))

  @property
  def params(self):
    return self._mean

  def multiply_fisher(self, vector):
    return (1. / self._var) * vector

  def multiply_fisher_factor(self, vector):
    return self._var**-0.5 * vector

  def multiply_fisher_factor_transpose(self, vector):
    return self.multiply_fisher_factor(vector)  # it's symmetric in this case

  def multiply_fisher_factor_replicated_one_hot(self, index):
    assert len(index) == 1, "Length of index was {}".format(len(index))
    ones_slice = array_ops.expand_dims(
        array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
        axis=-1)
    output_slice = self._var**-0.5 * ones_slice
    return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
                                 index[0])

  @property
  def fisher_factor_inner_shape(self):
    return array_ops.shape(self._mean)

  @property
  def fisher_factor_inner_static_shape(self):
    return self._mean.shape


class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
  """Negative log prob loss for a normal distribution with mean and variance.

  This class parameterizes a multivariate normal distribution with n independent
  dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not
  assume the variance is held constant. The Fisher Information for n = 1
  is given by,

  F = [[1 / variance,                0],
       [           0, 0.5 / variance^2]]

  where the parameters of the distribution are concatenated into a single
  vector as [mean, variance]. For n > 1, the mean parameter vector is
  concatenated with the variance parameter vector.

  See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation.
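
  For example, with `mean` and `variance` of shape [batch_size, n], the 'inner'
  vectors accepted by multiply_fisher_factor (and produced by
  multiply_fisher_factor_transpose) have shape [batch_size, 2*n], with the mean
  block first and the variance block second.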
  """

  def __init__(self, mean, variance, targets=None, seed=None):
    assert len(mean.shape) == 2, "Expect 2D mean tensor."
    assert len(variance.shape) == 2, "Expect 2D variance tensor."
    self._mean = mean
    self._variance = variance
    self._scale = math_ops.sqrt(variance)
    self._targets = targets
    super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)

  @property
  def targets(self):
    return self._targets

  @property
  def dist(self):
    return normal.Normal(loc=self._mean, scale=self._scale)

  @property
  def params(self):
    return self._mean, self._variance

  def _concat(self, mean, variance):
    return array_ops.concat([mean, variance], axis=-1)

  def _split(self, params):
    return array_ops.split(params, 2, axis=-1)

  @property
  def _fisher_mean(self):
    return 1. / self._variance

  @property
  def _fisher_mean_factor(self):
    return 1. / self._scale

  @property
  def _fisher_var(self):
    return 1. / (2 * math_ops.square(self._variance))

  @property
  def _fisher_var_factor(self):
    return 1. / (math_ops.sqrt(2.) * self._variance)

  def multiply_fisher(self, vecs):
    mean_vec, var_vec = vecs
    return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)

  def multiply_fisher_factor(self, vecs):
    mean_vec, var_vec = self._split(vecs)
    return (self._fisher_mean_factor * mean_vec,
            self._fisher_var_factor * var_vec)

  def multiply_fisher_factor_transpose(self, vecs):
    mean_vec, var_vec = vecs
    return self._concat(self._fisher_mean_factor * mean_vec,
                        self._fisher_var_factor * var_vec)

  def multiply_fisher_factor_replicated_one_hot(self, index):
    assert len(index) == 1, "Length of index was {}".format(len(index))
    index = index[0]

    if index < int(self._mean.shape[-1]):
      # Index corresponds to mean parameter.
      mean_slice = self._fisher_mean_factor[:, index]
      mean_slice = array_ops.expand_dims(mean_slice, axis=-1)
      mean_output = insert_slice_in_zeros(mean_slice, 1, int(
          self._mean.shape[1]), index)
      var_output = array_ops.zeros_like(mean_output)
    else:
      index -= int(self._mean.shape[-1])
      # Index corresponds to variance parameter.
      var_slice = self._fisher_var_factor[:, index]
      var_slice = array_ops.expand_dims(var_slice, axis=-1)
      var_output = insert_slice_in_zeros(var_slice, 1,
                                         int(self._variance.shape[1]), index)
      mean_output = array_ops.zeros_like(var_output)

    return mean_output, var_output

  @property
  def fisher_factor_inner_shape(self):
    return array_ops.concat(
        [
            array_ops.shape(self._mean)[:-1],
            2 * array_ops.shape(self._mean)[-1:]
        ],
        axis=0)

  @property
  def fisher_factor_inner_static_shape(self):
    shape = self._mean.shape.as_list()
    return tensor_shape.TensorShape(shape[:-1] + [2 * shape[-1]])

  def multiply_hessian(self, vector):
    raise NotImplementedError()

  def multiply_hessian_factor(self, vector):
    raise NotImplementedError()

  def multiply_hessian_factor_transpose(self, vector):
    raise NotImplementedError()

  def multiply_hessian_factor_replicated_one_hot(self, index):
    raise NotImplementedError()

  @property
  def hessian_factor_inner_shape(self):
    raise NotImplementedError()

  @property
  def hessian_factor_inner_static_shape(self):
    raise NotImplementedError()


class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
                                           NaturalParamsNegativeLogProbLoss):
  """Neg log prob loss for a categorical distribution parameterized by logits.


  Note that the Fisher (for a single case) of a categorical distribution, with
  respect to the natural parameters (i.e. the logits), is given by:

  F = diag(p) - p*p^T

  where p = softmax(logits).  F can be factorized as F = B * B^T where

  B = diag(q) - p*q^T

  where q is the entry-wise square root of p. This is easy to verify using the
  fact that q^T*q = 1.
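
  A quick numerical check of this factorization (a sketch using numpy rather
  than TensorFlow, with an arbitrary example logit vector):

    import numpy as np
    logits = np.array([1.0, 2.0, 0.5])
    p = np.exp(logits) / np.sum(np.exp(logits))
    q = np.sqrt(p)
    F = np.diag(p) - np.outer(p, p)
    B = np.diag(q) - np.outer(p, q)
    assert np.allclose(B.dot(B.T), F)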
  """

  def __init__(self, logits, targets=None, seed=None):
    """Instantiates a CategoricalLogitsNegativeLogProbLoss.

    Args:
      logits: Tensor of shape [batch_size, output_size]. Parameters for
        underlying distribution.
      targets: None or Tensor of shape [batch_size]. Each element contains an
        index in [0, output_size).
      seed: int or None. Default random seed when sampling.
    """
    self._logits_components = []
    self._targets_components = []
    self.register_additional_minibatch(logits, targets=targets)
    super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)

  def register_additional_minibatch(self, logits, targets=None):
    """Register an additiona minibatch's worth of parameters.

    Args:
      logits: Tensor of shape [batch_size, output_size]. Parameters for
        underlying distribution.
      targets: None or Tensor of shape [batch_size]. Each element contains an
        index in [0, output_size).
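
    A sketch of multi-tower usage (the per-tower logits/targets are assumed to
    be Tensors of the shapes described above):

      loss = CategoricalLogitsNegativeLogProbLoss(tower0_logits,
                                                  targets=tower0_targets)
      loss.register_additional_minibatch(tower1_logits,
                                         targets=tower1_targets)
      total_loss = loss.evaluate()  # summed across both towers' cases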
    """
    self._logits_components.append(logits)
    self._targets_components.append(targets)

  @property
  def _logits(self):
    return array_ops.concat(self._logits_components, axis=0)

  @property
  def input_minibatches(self):
    return self._logits_components

  @property
  def targets(self):
    if all(target is None for target in self._targets_components):
      return None
    return array_ops.concat(self._targets_components, axis=0)

  @property
  def dist(self):
    return categorical.Categorical(logits=self._logits)

  @property
  def _probs(self):
    return self.dist.probs

  @property
  def _sqrt_probs(self):
    return math_ops.sqrt(self._probs)

  @property
  def params(self):
    return self._logits

  def multiply_fisher(self, vector):
    probs = self._probs
    return vector * probs - probs * math_ops.reduce_sum(
        vector * probs, axis=1, keep_dims=True)

  def multiply_fisher_factor(self, vector):
    probs = self._probs
    sqrt_probs = self._sqrt_probs
    return sqrt_probs * vector - probs * math_ops.reduce_sum(
        sqrt_probs * vector, axis=1, keep_dims=True)

  def multiply_fisher_factor_transpose(self, vector):
    probs = self._probs
    sqrt_probs = self._sqrt_probs
    return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum(
        probs * vector, axis=1, keep_dims=True)

  def multiply_fisher_factor_replicated_one_hot(self, index):
    assert len(index) == 1, "Length of index was {}".format(len(index))
    probs = self._probs
    sqrt_probs = self._sqrt_probs
    sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1)
    padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1,
                                         int(sqrt_probs.shape[1]), index[0])
    return padded_slice - probs * sqrt_probs_slice

  @property
  def fisher_factor_inner_shape(self):
    return array_ops.shape(self._logits)

  @property
  def fisher_factor_inner_static_shape(self):
    return self._logits.shape


class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
                                        NaturalParamsNegativeLogProbLoss):
  """Neg log prob loss for multiple Bernoulli distributions param'd by logits.

  Represents N independent Bernoulli distributions where N = len(logits). Its
  Fisher Information matrix is given by,

  F = diag(p * (1-p))
  p = sigmoid(logits)

  As F is diagonal with positive entries, its factor B is,

  B = diag(sqrt(p * (1-p)))
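
  (For a single Bernoulli with logit l, d/dl log p(x|l) = x - sigmoid(l), whose
  variance under the model is p * (1-p); and since F is diagonal, its factor B
  is simply the element-wise square root of the diagonal.)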
  """

  def __init__(self, logits, targets=None, seed=None):
    self._logits = logits
    self._targets = targets
    super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed)

  @property
  def targets(self):
    return self._targets

  @property
  def dist(self):
    return bernoulli.Bernoulli(logits=self._logits)

  @property
  def _probs(self):
    return self.dist.probs

  @property
  def params(self):
    return self._logits

  def multiply_fisher(self, vector):
    return self._probs * (1 - self._probs) * vector

  def multiply_fisher_factor(self, vector):
    return math_ops.sqrt(self._probs * (1 - self._probs)) * vector

  def multiply_fisher_factor_transpose(self, vector):
    return self.multiply_fisher_factor(vector)  # it's symmetric in this case

  def multiply_fisher_factor_replicated_one_hot(self, index):
    assert len(index) == 1, "Length of index was {}".format(len(index))
    probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1)
    output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice))
    return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),
                                 index[0])

  @property
  def fisher_factor_inner_shape(self):
    return array_ops.shape(self._logits)

  @property
  def fisher_factor_inner_static_shape(self):
    return self._logits.shape


def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
  """Inserts slice into a larger tensor of zeros.

  Forms a new tensor which is the same shape as slice_to_insert, except that
  the dimension given by 'dim' is expanded to the size given by 'dim_size'.
  'position' determines the position (index) at which to insert the slice within
  that dimension.

  Assumes slice_to_insert.shape[dim] = 1.
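
  For example, inserting a slice of shape [2, 1] with dim=1, dim_size=4 and
  position=2 yields a [2, 4] tensor whose column 2 holds the slice and whose
  remaining entries are zero.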

  Args:
    slice_to_insert: The slice to insert.
    dim: The dimension which to expand with zeros.
    dim_size: The new size of the 'dim' dimension.
    position: The position of 'slice_to_insert' in the new tensor.

  Returns:
    The new tensor.

  Raises:
    ValueError: If the slice's shape at the given dim is not 1.
  """
  slice_shape = slice_to_insert.shape
  if slice_shape[dim] != 1:
    raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but "
                     "was {}".format(dim, slice_to_insert.shape[dim]))

  before = [0] * int(len(slice_shape))
  after = before[:]
  before[dim] = position
  after[dim] = dim_size - position - 1

  return array_ops.pad(slice_to_insert, list(zip(before, after)))
